{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "array([0.00319858, 0.002784  , 0.04810512, 0.03201213], dtype=float32)"
      ]
     },
     "execution_count": 1,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "import gym\n",
    "\n",
    "\n",
    "#定义环境\n",
    "class MyWrapper(gym.Wrapper):\n",
    "\n",
    "    def __init__(self):\n",
    "        env = gym.make('CartPole-v1', render_mode='rgb_array')\n",
    "        super().__init__(env)\n",
    "        self.env = env\n",
    "        self.step_n = 0\n",
    "\n",
    "    def reset(self):\n",
    "        state, _ = self.env.reset()\n",
    "        self.step_n = 0\n",
    "        return state\n",
    "\n",
    "    def step(self, action):\n",
    "        state, reward, terminated, truncated, info = self.env.step(action)\n",
    "        done = terminated or truncated\n",
    "        self.step_n += 1\n",
    "        if self.step_n >= 200:\n",
    "            done = True\n",
    "        return state, reward, done, info\n",
    "\n",
    "\n",
    "env = MyWrapper()\n",
    "\n",
    "env.reset()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "image/png": "iVBORw0KGgoAAAANSUhEUgAAAigAAAF7CAYAAAD4/3BBAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjYuMywgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/P9b71AAAACXBIWXMAAA9hAAAPYQGoP6dpAAApj0lEQVR4nO3df3SU5Z3//9dMkhkIYSYNkEwiCaJSIEJQQcOsrUuXlADR1TWer1pWsMuBI5t4qrEW01oVu8e42rP+6CL8sV1xP0dKa7+iKwqKIGGtETAl5Zekwoc2WJgEpZkJaH7NXJ8/WKYdRciEkLkmeT7Ouc/J3Nc197zv6zDJi/u+7vt2GGOMAAAALOJMdAEAAABfREABAADWIaAAAADrEFAAAIB1CCgAAMA6BBQAAGAdAgoAALAOAQUAAFiHgAIAAKxDQAEAANZJaEBZvny5Lr74Yg0ZMkTFxcXavn17IssBAACWSFhA+eUvf6mqqio9/PDD+u1vf6spU6aotLRULS0tiSoJAABYwpGohwUWFxfr6quv1r//+79LkiKRiPLz83X33XfrgQceSERJAADAEqmJ+NDOzk7V19eruro6us7pdKqkpER1dXVf6t/R0aGOjo7o60gkouPHj2vEiBFyOBz9UjMAADg/xhi1tbUpLy9PTufZT+IkJKB88sknCofDysnJiVmfk5Oj/fv3f6l/TU2Nli1b1l/lAQCAC+jw4cMaPXr0WfskJKDEq7q6WlVVVdHXwWBQBQUFOnz4sDweTwIrAwAAPRUKhZSfn6/hw4efs29CAsrIkSOVkpKi5ubmmPXNzc3y+Xxf6u92u+V2u7+03uPxEFAAAEgyPZmekZCreFwul6ZOnapNmzZF10UiEW3atEl+vz8RJQEAAIsk7BRPVVWVFixYoGnTpumaa67R008/rZMnT+q73/1uokoCAACWSFhAufXWW3Xs2DE99NBDCgQCuuKKK7Rhw4YvTZwFAACDT8Lug3I+QqGQvF6vgsEgc1AAAEgS8fz95lk8AADAOgQUAABgHQIKAACwDgEFAABYh4ACAACsQ0ABAADWIaAAAADrEFAAAIB1CCgAAMA6BBQAAGAdAgoAALAOAQUAAFiHgAIAAKxDQAEAANYhoAAAAOsQUAAAgHUIKAAAwDoEFAAAYB0CCgAAsA4BBQAAWIeAAgAArENAAQAA1iGgAAAA6xBQAACAdQgoAADAOgQUAABgHQIKAACwDgEFAABYh4ACAACsQ0ABAADW6fOA8sgjj8jhcMQsEyZMiLa3t7eroqJCI0aMUEZGhsrLy9Xc3NzXZQAAgCR2QY6gXH755Tp69Gh0effdd6Nt9957r1577TW99NJLqq2t1ZEjR3TzzTdfiDIAAECSSr0gG01Nlc/n+9L6YDCon//851q9erX+7u/+TpL0/PPPa+LEiXr//fc1ffr0C1EOAABIMhfkCMpHH32kvLw8XXLJJZo3b56ampokSfX19erq6lJJSUm074QJE1RQUKC6urqv3F5HR4dCoVDMAgAABq4+DyjFxcVatWqVNmzYoBUrVujQoUP65je/qba2NgUCAblcLmVmZsa8JycnR4FA4Cu3WVNTI6/XG13y8/P7umwAAGCRPj/FM2fOnOjPRUVFKi4u1pgxY/SrX/1KQ4cO7dU2q6urVVVVFX0dCoUIKQAADGAX/DLjzMxMff3rX9eBAwfk8/nU2dmp1tbWmD7Nzc1nnLNymtvtlsfjiVkAAMDAdcEDyokTJ3Tw4EHl5uZq6tSpSktL06ZNm6LtjY2Nampqkt/vv9ClAACAJNHnp3i+//3v64YbbtCYMWN05MgRPfzww0pJSdHtt98ur9erhQsXqqqqSllZWfJ4PLr77rvl9/u5ggcAAET1eUD5+OOPdfvtt+vTTz/VqFGj9I1vfEPvv/++Ro0aJUl66qmn5HQ6VV5ero6ODpWWluq5557r6zIAAEAScxhjTKKLiFcoFJLX61UwGGQ+CgAASSKev988iwcAAFiHgAIAAKxDQAEAANYhoAAAAOsQUAAAgHUIKAAAwDoEFAAAYB0CCgAAsA4BBQAAWIeAAgAArENAAQAA1iGgAAAA6xBQAACAdQgoAADAOgQUAABgHQIKAACwDgEFAABYh4ACAACsQ0ABAADWIaAAAADrEFAAAIB1CCgAAMA6BBQAAGAdAgoAALAOAQUAAFiHgAIAAKxDQAEAANYhoAAAAOsQUAAAgHUIKAAAwDoEFAAAYJ24A8rWrVt1ww03KC8vTw6HQ6+88kpMuzFGDz30kHJzczV06FCVlJToo48+iulz/PhxzZs3Tx6PR5mZmVq4cKFOnDhxXjsCAAAGjrgDysmTJzVlyhQtX778jO1PPPGEnn32Wa1cuVLbtm3TsGHDVFpaqvb29mifefPmae/evdq4caPWrVunrVu3avHixb3fCwAAMKA4jDGm1292OLR27VrddNNNkk4dPcnLy9N9992n73//+5KkYDConJwcrVq1Srfddps+/PBDFRYWaseOHZo2bZokacOGDZo7d64+/vhj5eXlnfNzQ6GQvF6vgsGgPB5Pb8sHAAD9KJ6/3306B+XQoUMKBAIqKSmJrvN6vSouLlZdXZ0kqa6uTpmZmdFwIkklJSVyOp3atm3bGbfb0dGhUCgUswAAgIGrTwNKIBCQJOXk5MSsz8nJibYFAgFlZ2fHtKempiorKyva54tqamrk9XqjS35+fl+WDQAALJMUV/FUV1crGAxGl8OHDye6JAAAcAH1aUDx+XySpObm5pj1zc3N0Tafz6eWlpaY9u7ubh0/fjza54vcbrc8Hk/MAgAABq4+DShjx46Vz+fTpk2boutCoZC2bdsmv98vSfL7/WptbVV9fX20z+bNmxWJRFRcXNyX5QAAgCSVGu8bTpw4oQMHDkRfHzp0SA0NDcrKylJBQYHuuece/cu//IvGjRunsWPH6sc//rHy8vKiV/pMnDhRs2fP1qJFi7Ry5Up1dXWpsrJSt912W4+u4AEAAANf3AHlgw8+0Le+9a3o66qqKknSggULtGrVKv3gBz/QyZMntXjxYrW2tuob3/iGNmzYoCFDhkTf8+KLL6qyslIzZ86U0+lUeXm5nn322T7YHQAAMBCc131QEoX7oAAAkHwSdh8UAACAvkBAAQAA1iGgAAAA6xBQAACAdQgoAADAOgQUAABgHQIKAACwDgEFAABYh4ACAACsQ0ABAADWIaAAAADrEFAAAIB1CCgAAMA6BBQAAGAdAgoAALAOAQUAAFiHgAIAAKxDQAEAANYhoAAAAOsQUAAAgHUIKAAAwDoEFAAAYB0CCgAAsA4BBQAAWIeAAgAArENAAQAA1iGgAAAA6xBQAACAdQgoAADAOgQUAABgnbgDytatW3XDDTcoLy9PDodDr7zySkz7nXfeKYfDEbPMnj07ps/x48c1b948eTweZWZmauHChTpx4sR57QgAABg44g4oJ0+e1JQpU7R8+fKv7DN79mwdPXo0uvziF7+IaZ83b5727t2rjRs3at26ddq6dasWL14cf/UAAGBASo33DXPmzNGcOXPO2sftdsvn852x7cMPP9SGDRu0Y8cOTZs2TZL0s5/9THPnztVPf/pT5eXlxVsSAAAYYC7IHJQtW7YoOztb48eP15IlS/Tpp59G2+rq6pSZmRkNJ5JUUlIip9Opbdu2nXF7HR0dCoVCMQsAABi4+jygzJ49W//1X/+lTZs26V//9V9VW1urOXPmKBwOS5ICgYCys7Nj3pOamqqsrCwFAoEzbrOmpkZerze65Ofn93XZAADAInGf4jmX2267Lfrz5MmTVVRUpEsvvVRbtmzRzJkze7XN6upqVVVVRV+HQiFCCgAAA9gFv8z4kksu0ciRI3XgwAFJks/nU0tLS0yf7u5uHT9+/Cvnrbjdbnk8npgFAAAMXBc8oHz88cf69NNPlZubK0ny+/1qbW1VfX19tM/mzZsViURUXFx8ocsBAABJIO5TPCdOnIgeDZGkQ4cOqaGhQVlZWcrKytKyZctUXl4un8+ngwcP6gc/+IEuu+wylZaWSpImTpyo2bNna9GiRVq5cqW6urpUWVmp2267jSt4AACAJMlhjDHxvGHLli361re+9aX1CxYs0IoVK3TTTTdp586dam1tVV5enmbNmqWf/OQnysnJifY9fvy4Kisr9dprr8npdKq8vFzPPvusMjIyelRDKBSS1+tVMBjkdA8AAEkinr/fcQcUGxBQAABIPvH8/eZZPAAAwDoEFAAAYB0CCgAAsA4BBQAAWIeAAgAArENAAQAA1iGgAAAA6xBQAACAdQgoAADAOgQUAABgnbgfFggAF1rnyVb9Yev/OWsfZ6pLl5YslsPh6KeqAPQnAgoAqxhj1N1+UsGm3Wftl+Ia2k8VAUgETvEAsE6kqz3RJQBIMAIKAOuEuzoSXQKABCOgALBOmCMowKBHQAFgHQIKAAIKAOswBwUAAQWAdZiDAoCAAsAyRhECCjDoEVAA2MVIwaY95+zmGV3YD8UASBQCCgDLGH12/ONz9srIuaQfagGQKAQUAEmJO8kCAxsBBUBSSklzJ7oEABcQAQVAUnKmDkl0CQAuIAIKgKTkdBFQgIGMgAIgKaWkEVCAgYyAAiApEVCAgY2AAsAqpof9mCQLDGwEFABWMeHuHvVzpKTI4XBc4GoAJAoBBYBVeA4PACnOgFJTU6Orr75aw4cPV3Z2tm666SY1NjbG9Glvb1dFRYVGjBihjIwMlZeXq7m5OaZPU1OTysrKlJ6eruzsbN1///3q7u7Z/5oADGynnmTc0xM9AAaquAJKbW2tKioq9P7772vjxo3q6urSrFmzdPLkyWife++9V6+99ppeeukl1dbW6siRI7r55puj7eFwWGVlZers7NR7772nF154QatWrdJDDz3Ud3sFIGmFu9oTXQIACziMMb3+r8qxY8eUnZ2t2tpaXXfddQoGgxo1apRWr16tW265RZK0f/9+TZw4UXV1dZo+fbrWr1+v66+/XkeOHFFOTo4kaeXKlVq6dKmOHTsml8t1zs8NhULyer0KBoPyeDy9LR+AhdqOfqT9r/1UOsevpivvfEqp7mH9VBWAvhDP3+/zmoMSDAYlSVlZWZKk+vp6dXV1qaSkJNpnwoQJKigoUF1dnSSprq5OkydPjoYTSSotLVUoFNLevXvP+DkdHR0KhUIxC4CBKdzZzhkeAL0PKJFIRPfcc4+uvfZaTZo0SZIUCATkcrmUmZkZ0zcnJ0eBQCDa56/Dyen2021nUlNTI6/XG13y8/N7WzYAy4W7mSQL4DwCSkVFhfbs2aM1a9b0ZT1nVF1drWAwGF0OHz58wT8TQGJEuIoHgKTU3rypsrJS69at09atWzV69Ojoep/Pp87OTrW2tsYcRWlubpbP54v22b59e8z2Tl/lc7rPF7ndbrnd3JQJGAw6TxzXuc7xpLjSJXEPFGAgi+sIijFGlZWVWrt2rTZv3qyxY8fGtE+dOlVpaWnatGlTdF1jY6Oamprk9/slSX6/X7t371ZLS0u0z8aNG+XxeFRYWHg++wJgAPj093Xn7DNi3HQ5U9L6oRoAiRLXEZSKigqtXr1ar776qoYPHx6dM+L1ejV06FB5vV4tXLhQVVVVysrKksfj0d133y2/36/p06dLkmbNmqXCwkLdcccdeuKJJxQIBPTggw+qoqKCoyQAesSZ5pa4iywwoMUVUFasWCFJmjFjRsz6559/Xnfeeack6amnnpLT6VR5ebk6OjpUWlqq5557Lto3JSVF69at05IlS+T3+zVs2DAtWLBAjz766PntCYBBIyXt3LcjAJDczus+KInCfVCAgWvX6h+qo+2Ts/bJ99+qnEkz5HCm9FNVAPpCv90HBQASgScZAwMfAQVA0nGmucVVPMDARkABkHScqS4myQIDHAEFgDV6OiWOUzzAwEdAAWAPY3r0GB4H90ABBjwCCgBrhLs7z/kUY+nU7BMHp3iAAY2AAsAaprtTUiTRZQCwAAEFgDXC3R09nocCYGAjoACwRqS7q0eneAAMfAQUANaIcAQFwP8ioACwRqSHk2QBDHwEFADWiHR3ST260BjAQEdAAWCNU6d4uIoHAAEFgEVa//A7hTs+P2sfz0UTlTYss38KApAwBBQA1gh3tetcp3hSh2TIyZ1kgQGPgAIgqThS0iQHv7qAgY5vOYCk4kxNk8PJry5goONbDiCpOFPS5OAICjDg8S0HkFScqWkSR1CAAY9vOYCk4uAICjAo8C0HkFScqQQUYDDgWw7ACqdu0Hbuu8g6namSw3HhCwKQUAQUAFYwkbBMpAd3kXU45CCgAAMeAQWAFSLh7p4FFACDAgEFgBVMuEsmEk50GQAsQUABYIVTR1AIKABOIaAAsIIJd8sYAgqAUwgoAKxgOIIC4K8QUABYgUmyAP4aAQWAFbrbQ4p0fn7WPs5Ul1LShvRTRQASiYACwAonj/1RHW2fnLXPkEyfhmZd1E8VAUikuAJKTU2Nrr76ag0fPlzZ2dm66aab1NjYGNNnxowZcvzvjZROL3fddVdMn6amJpWVlSk9PV3Z2dm6//771d3dff57A2BAczhT5HCmJLoMAP0gNZ7OtbW1qqio0NVXX63u7m798Ic/1KxZs7Rv3z4NGzYs2m/RokV69NFHo6/T09OjP4fDYZWVlcnn8+m9997T0aNHNX/+fKWlpemxxx7rg10CMFA5nClypMT1awtAkorrm75hw4aY16tWrVJ2drbq6+t13XXXRdenp6fL5/OdcRtvvfWW9u3bp7fffls5OTm64oor9JOf/ERLly7VI488IpfL1YvdADAYOJypp57FA2DAO685KMFgUJKUlZUVs/7FF1/UyJEjNWnSJFVXV+uzzz6LttXV1Wny5MnKycmJristLVUoFNLevXvP+DkdHR0KhUIxC4DBx5HCERRgsOj1Nz0Sieiee+7Rtddeq0mTJkXXf+c739GYMWOUl5enXbt2aenSpWpsbNTLL78sSQoEAjHhRFL0dSAQOONn1dTUaNmyZb0tFcAAceoUD3NQgMGg1wGloqJCe/bs0bvvvhuzfvHixdGfJ0+erNzcXM2cOVMHDx7UpZde2qvPqq6uVlVVVfR1KBRSfn5+7woHkLSczhRO8QCDRK9O8VRWVmrdunV65513NHr06LP2LS4uliQdOHBAkuTz+dTc3BzT5/Trr5q34na75fF4YhYAg8+pq3gIKMBgEFdAMcaosrJSa9eu1ebNmzV27NhzvqehoUGSlJubK0ny+/3avXu3Wlpaon02btwoj8ejwsLCeMoBMEAYY3rUz+FMZQ4KMEjE9U2vqKjQ6tWr9eqrr2r48OHROSNer1dDhw7VwYMHtXr1as2dO1cjRozQrl27dO+99+q6665TUVGRJGnWrFkqLCzUHXfcoSeeeEKBQEAPPvigKioq5Ha7+34PAdjPmJ49h+d/760EYOCL6wjKihUrFAwGNWPGDOXm5kaXX/7yl5Ikl8ult99+W7NmzdKECRN03333qby8XK+99lp0GykpKVq3bp1SUlLk9/v1j//4j5o/f37MfVMADC7GhBXp7kp0GQAsEtcRlHMdhs3Pz1dtbe05tzNmzBi98cYb8Xw0gAHMRCKKhLmbNIC/4Fk8ABLORMIyYY6gAPgLAgqAhDMmoggBBcBfIaAASLxIWIZTPAD+CgEFQMKdmoPCERQAf0FAAZBwxoQJKABiEFAAJFx78Jja/rT/rH1ShwxXZkFRP1UEINEIKAASz0TOeaM2h9OpFPfQfioIQKIRUAAkB4dTzlRXoqsA0E8IKACSgsPhlDMlLdFlAOgnBBQAScHhcHAEBRhECCgAkgNHUIBBhYACICk4nA450ziCAgwWBBQAycHhlDOFgAIMFgQUAEnBwVU8wKBCQAGQUMYYmUjPnsPjcPIrCxgs+LYDSDCjcFdnoosAYBkCCoDEMkaRbgIKgFgEFAAJZYxRpLsj0WUAsAwBBUCCcYoHwJelJroAAMktHA7LGNPr90fC3erubD9nP2Ok7u6eTaY9E6fTKSeTbIGkQUABcF7Ky8v1+uuv9/r9bleqFpVdqe/MnHTWfnv27NZV83r/NONHHnlEP/rRj3r9fgD9i4AC4LyEw+HzOrKR6pRcqec+smGMOa/PCYfDvX4vgP5HQAGQUO60FE2+JFuS1No1Sn/uzlF3xC2X8zONdP1Jw1JCikSMdv/flgRXCqA/EVAAJFRaaorG54/UkY5LdfCzK/VZeLgiSlWKo0sfdwQ1KWOr0tWihgOBRJcKoB8xYwxAwn3SeZH2nvimToSzFFGaJIfCxqVQ9yjtCJapPZKhzzu6El0mgH5EQAGQUB2RdO0IzVW3OfNzdrrMEG398/+nk+0EFGAwIaAAsIDjrK3GSJ939H6CLIDkQ0ABYD0jcYoHGGQIKACSwmcEFGBQIaAASCi383NdOfwtOXTm+5Q41a1rvf8/R1CAQSaugLJixQoVFRXJ4/HI4/HI7/dr/fr10fb29nZVVFRoxIgRysjIUHl5uZqbm2O20dTUpLKyMqWnpys7O1v333//ed18CUCyM8px/UGXZ7yrIc42OdQtycipLqU7gyr2rtOwlFaOoACDTFz3QRk9erQef/xxjRs3TsYYvfDCC7rxxhu1c+dOXX755br33nv1+uuv66WXXpLX61VlZaVuvvlm/eY3v5F06k6OZWVl8vl8eu+993T06FHNnz9faWlpeuyxxy7IDgKwW3tnt179zX5J+3W8a7s+6RytTjNEQ5wnlOP6g/6c+md1d0fU1R1JdKkA+pHDnM9TviRlZWXpySef1C233KJRo0Zp9erVuuWWWyRJ+/fv18SJE1VXV6fp06dr/fr1uv7663XkyBHl5ORIklauXKmlS5fq2LFjcrnOfJnhF4VCIXm9Xt155509fg+AC2PDhg1qampKdBnnNG3aNF111VWJLgMY1Do7O7Vq1SoFg0F5PJ6z9u31nWTD4bBeeuklnTx5Un6/X/X19erq6lJJSUm0z4QJE1RQUBANKHV1dZo8eXI0nEhSaWmplixZor179+rKK68842d1dHSoo6Mj+joUCkmS7rjjDmVkZPR2FwD0gX379iVFQLnqqqu0cOHCRJcBDGonTpzQqlWretQ37oCye/du+f1+tbe3KyMjQ2vXrlVhYaEaGhrkcrmUmZkZ0z8nJ0eBwKlbVAcCgZhwcrr9dNtXqamp0bJly760ftq0aedMYAAurC9+52110UUX6Zprrkl0GcCgdvoAQ0/EfRXP+PHj1dDQoG3btmnJkiVasGCB9u3bF+9m4lJdXa1gMBhdDh8+fEE/DwAAJFbcR1BcLpcuu+wySdLUqVO1Y8cOPfPMM7r11lvV2dmp1tbWmP9RNTc3y+fzSZJ8Pp+2b98es73TV/mc7nMmbrdbbrc73lIBAECSOu/7oEQiEXV0dGjq1KlKS0vTpk2bom2NjY1qamqS3++XJPn9fu3evVstLX95bPrGjRvl8XhUWFh4vqUAAIABIq4jKNXV1ZozZ44KCgrU1tam1atXa8uWLXrzzTfl9Xq1cOFCVVVVKSsrSx6PR3fffbf8fr+mT58uSZo1a5YKCwt1xx136IknnlAgENCDDz6oiooKjpAAAICouAJKS0uL5s+fr6NHj8rr9aqoqEhvvvmmvv3tb0uSnnrqKTmdTpWXl6ujo0OlpaV67rnnou9PSUnRunXrtGTJEvn9fg0bNkwLFizQo48+2rd7BQAAklpcAeXnP//5WduHDBmi5cuXa/ny5V/ZZ8yYMXrjjTfi+VgAADDI8CweAABgHQIKAACwDgEFAABYh4ACAACs0+tn8QCAJE2fPl2pqfb/KpkwYUKiSwAQh/N+mnEinH6acU+ehggAAOwQz99vTvEAAADrEFAAAIB1CCgAAMA6BBQAAGAdAgoAALAOAQUAAFiHgAIAAKxDQAEAANYhoAAAAOsQUAAAgHUIKAAAwDoEFAAAYB0CCgAAsA4BBQAAWIeAAgAArENAAQAA1iGgAAAA6xBQAACAdQgoAADAOgQUAABgHQIKAACwDgEFAABYh4ACAACsQ0ABAADWiSugrFixQkVFRfJ4PPJ4PPL7/Vq/fn20fcaMGXI4HDHLXXfdFbONpqYmlZWVKT09XdnZ2br//vvV3d3dN3sDAAAGhNR4Oo8ePVqPP/64xo0bJ2OMXnjhBd14443auXOnLr/8cknSokWL9Oijj0bfk56eHv05HA6rrKxMPp9P7733no4ePar58+crLS1Njz32WB/tEgAASHYOY4w5nw1kZWXpySef1MKFCzVjxgxdccUVevrpp8/Yd/369br++ut15MgR5eTkSJJWrlyppUuX6tixY3K5XD36zFAoJK/Xq2AwKI/Hcz7lAwCAfhLP3+9ez0EJh8Nas2aNTp48Kb/fH13/4osvauTIkZo0aZKqq6v12WefRdvq6uo0efLkaDiRpNLSUoVCIe3du/crP6ujo0OhUChmAQAAA1dcp3gkaffu3fL7/Wpvb1dGRobWrl2rwsJCSdJ3vvMdjRkzRnl5edq1a5eWLl2qxsZGvfzyy5KkQCAQE04kRV8HAoGv/MyamhotW7Ys3lIBAECSijugjB8/Xg0NDQoGg/r1r3+tBQsWqLa2VoWFhVq8eHG03+TJk5Wbm6uZM2fq4MGDuvTSS3tdZHV1taqqqqKvQ6GQ8vPze709AABgt7hP8bhcLl122WWaOnWqampqNGXKFD3zzDNn7FtcXCxJOnDggCTJ5/Opubk5ps/p1z6f7ys/0+12R68cOr0AAICB67zvgxKJRNTR0XHGtoaGBklSbm6uJMnv92v37t1qaWmJ9tm4caM8Hk/0NBEAAEBcp3iqq6s1Z84cFRQUqK2tTatXr9aWLVv05ptv6uDBg1q9erXmzp2rESNGaNeuXbr33nt13XXXqaioSJI0a9YsFRYW6o477tATTzyhQCCgBx98UBUVFXK73RdkBwEAQPKJK6C0tLRo/vz5Onr0qLxer4qKivTmm2/q29/+tg4fPqy3335bTz/9tE6ePKn8/HyVl5frwQcfjL4/JSVF69at05IlS+T3+zVs2DAtWLAg5r4pAAAA530flETgPigAACSffrkPCgAAwIVCQAEAANYhoAAAAOsQUAAAgHUIKAAAwDoEFAAAYB0CCgAAsA4BBQAAWIeAAgAArENAAQAA1iGgAAAA6xBQAACAdQgoAADAOgQUAABgHQIKAACwDgEFAABYh4ACAACsQ0ABAADWIaAAAADrEFAAAIB1CCgAAMA6BBQAAGAdAgoAALAOAQUAAFiHgAIAAKxDQAEAANYhoAAAAOsQUAAAgHUIKAAAwDoEFAAAYB0CCgAAsA4BBQAAWIeAAgAArJOa6AJ6wxgjSQqFQgmuBAAA9NTpv9un/46fTVIGlLa2NklSfn5+gisBAADxamtrk9frPWsfh+lJjLFMJBJRY2OjCgsLdfjwYXk8nkSXlLRCoZDy8/MZxz7AWPYdxrJvMI59h7HsG8YYtbW1KS8vT07n2WeZJOURFKfTqYsuukiS5PF4+MfSBxjHvsNY9h3Gsm8wjn2HsTx/5zpychqTZAEAgHUIKAAAwDpJG1Dcbrcefvhhud3uRJeS1BjHvsNY9h3Gsm8wjn2Hsex/STlJFgAADGxJewQFAAAMXAQUAABgHQIKAACwDgEFAABYJykDyvLly3XxxRdryJAhKi4u1vbt2xNdknW2bt2qG264QXl5eXI4HHrllVdi2o0xeuihh5Sbm6uhQ4eqpKREH330UUyf48ePa968efJ4PMrMzNTChQt14sSJftyLxKupqdHVV1+t4cOHKzs7WzfddJMaGxtj+rS3t6uiokIjRoxQRkaGysvL1dzcHNOnqalJZWVlSk9PV3Z2tu6//351d3f3564k1IoVK1RUVBS9yZXf79f69euj7Yxh7z3++ONyOBy65557ousYz5555JFH5HA4YpYJEyZE2xnHBDNJZs2aNcblcpn//M//NHv37jWLFi0ymZmZprm5OdGlWeWNN94wP/rRj8zLL79sJJm1a9fGtD/++OPG6/WaV155xfzud78zf//3f2/Gjh1rPv/882if2bNnmylTppj333/f/M///I+57LLLzO23397Pe5JYpaWl5vnnnzd79uwxDQ0NZu7cuaagoMCcOHEi2ueuu+4y+fn5ZtOmTeaDDz4w06dPN3/zN38Tbe/u7jaTJk0yJSUlZufOneaNN94wI0eONNXV1YnYpYT47//+b/P666+b3//+96axsdH88Ic/NGlpaWbPnj3GGMawt7Zv324uvvhiU1RUZL73ve9F1zOePfPwww+byy+/3Bw9ejS6HDt2LNrOOCZW0gWUa665xlRUVERfh8Nhk5eXZ2pqahJYld2+GFAikYjx+XzmySefjK5rbW01brfb/OIXvzDGGLNv3z4jyezYsSPaZ/369cbhcJg//elP/Va7bVpaWowkU1tba4w5NW5paWnmpZdeivb58MMPjSRTV1dnjDkVFp1OpwkEAtE+K1asMB6Px3R0dPTvDljka1/7mvmP//gPxrCX2trazLhx48zGjRvN3/7t30YDCuPZcw8//LCZMmXKGdsYx8RLqlM8nZ2dqq+vV0lJSXSd0+lUSUmJ6urqElhZcjl06JACgUDMOHq9XhUXF0fHsa6uTpmZmZo2bVq0T0lJiZxOp7Zt29bvNdsiGAxKkrKysiRJ9fX16urqihnLCRMmqKCgIGYsJ0+erJycnGif0tJShUIh7d27tx+rt0M4HNaaNWt08uRJ+f1+xrCXKioqVFZWFjNuEv8m4/XRRx8pLy9Pl1xyiebNm6empiZJjKMNkuphgZ988onC4XDMPwZJysnJ0f79+xNUVfIJBAKSdMZxPN0WCASUnZ0d056amqqsrKxon8EmEononnvu0bXXXqtJkyZJOjVOLpdLmZmZMX2/OJZnGuvTbYPF7t275ff71d7eroyMDK1du1aFhYVqaGhgDOO0Zs0a/fa3v9WOHTu+1Ma/yZ4rLi7WqlWrNH78eB09elTLli3TN7/5Te3Zs4dxtEBSBRQgkSoqKrRnzx69++67iS4lKY0fP14NDQ0KBoP69a9/rQULFqi2tjbRZSWdw4cP63vf+542btyoIUOGJLqcpDZnzpzoz0VFRSouLtaYMWP0q1/9SkOHDk1gZZCS7CqekSNHKiUl5UuzqJubm+Xz+RJUVfI5PVZnG0efz6eWlpaY9u7ubh0/fnxQjnVlZaXWrVund955R6NHj46u9/l86uzsVGtra0z/L47lmcb6dNtg4XK5dNlll2nq1KmqqanRlClT9MwzzzCGcaqvr1dLS4uuuuoqpaamKjU1VbW1tXr22WeVmpqqnJwcxrOXMjMz9fWvf10HDhzg36UFkiqguFwuTZ06VZs2bYqui0Qi2rRpk/x+fwIrSy5jx46Vz+eLGcdQKKRt27ZFx9Hv96u1tVX19fXRPps3b1YkElFxcXG/15woxhhVVlZq7dq12rx5s8aOHRvTPnXqVKWlpcWMZWNjo5qammLGcvfu3TGBb+PGjfJ4PCosLOyfHbFQJBJRR0cHYxinmTNnavfu3WpoaIgu06ZN07x586I/M569c+LECR08eFC5ubn8u7RBomfpxmvNmjXG7XabVatWmX379pnFixebzMzMmFnUODXDf+fOnWbnzp1Gkvm3f/s3s3PnTvPHP/7RGHPqMuPMzEzz6quvml27dpkbb7zxjJcZX3nllWbbtm3m3XffNePGjRt0lxkvWbLEeL1es2XLlphLET/77LNon7vuussUFBSYzZs3mw8++MD4/X7j9/uj7acvRZw1a5ZpaGgwGzZsMKNGjRpUlyI+8MADpra21hw6dMjs2rXLPPDAA8bhcJi33nrLGMMYnq+/vorHGMazp+677z6zZcsWc+jQIfOb3/zGlJSUmJEjR5qWlhZjDOOYaEkXUIwx5mc/+5kpKCgwLpfLXHPNNeb9999PdEnWeeedd4ykLy0LFiwwxpy61PjHP/6xycnJMW6328ycOdM0NjbGbOPTTz81t99+u8nIyDAej8d897vfNW1tbQnYm8Q50xhKMs8//3y0z+eff27++Z//2Xzta18z6enp5h/+4R/M0aNHY7bzhz/8wcyZM8cMHTrUjBw50tx3332mq6urn/cmcf7pn/7JjBkzxrhcLjNq1Cgzc+bMaDgxhjE8X18MKIxnz9x6660mNzfXuFwuc9FFF5lbb73VHDhwINrOOCaWwxhjEnPsBgAA4MySag4KAAAYHAgoAADAOgQUAABgHQIKAACwDgEFAABYh4ACAACsQ0ABAADWIaAAAADrEFAAAIB1CCgAAMA6BBQAAGAdAgoAALDO/wPVnRFvxr2W1AAAAABJRU5ErkJggg==\n",
      "text/plain": [
       "<Figure size 640x480 with 1 Axes>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "from matplotlib import pyplot as plt\n",
    "\n",
    "%matplotlib inline\n",
    "\n",
    "\n",
    "#打印游戏\n",
    "def show():\n",
    "    plt.imshow(env.render())\n",
    "    plt.show()\n",
    "\n",
    "\n",
    "show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {
    "scrolled": false
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "这个游戏的状态用4个数字表示,我也不知道这4个数字分别是什么意思,反正这4个数字就能描述游戏全部的状态\n",
      "state= [-0.00165764  0.02087691  0.017921    0.02388013]\n",
      "这个游戏一共有2个动作,不是0就是1\n",
      "env.action_space= Discrete(2)\n",
      "随机一个动作\n",
      "action= 0\n",
      "执行一个动作,得到下一个状态,奖励,是否结束\n",
      "state= [-0.00124011 -0.1744974   0.0183986   0.32216302]\n",
      "reward= 1.0\n",
      "over= False\n"
     ]
    }
   ],
   "source": [
    "#测试游戏环境\n",
    "def test_env():\n",
    "    state = env.reset()\n",
    "    print('这个游戏的状态用4个数字表示,我也不知道这4个数字分别是什么意思,反正这4个数字就能描述游戏全部的状态')\n",
    "    print('state=', state)\n",
    "    #state= [ 0.03490619  0.04873464  0.04908862 -0.00375859]\n",
    "\n",
    "    print('这个游戏一共有2个动作,不是0就是1')\n",
    "    print('env.action_space=', env.action_space)\n",
    "    #env.action_space= Discrete(2)\n",
    "\n",
    "    print('随机一个动作')\n",
    "    action = env.action_space.sample()\n",
    "    print('action=', action)\n",
    "    #action= 1\n",
    "\n",
    "    print('执行一个动作,得到下一个状态,奖励,是否结束')\n",
    "    state, reward, over, _ = env.step(action)\n",
    "\n",
    "    print('state=', state)\n",
    "    #state= [ 0.02018229 -0.16441101  0.01547085  0.2661691 ]\n",
    "\n",
    "    print('reward=', reward)\n",
    "    #reward= 1.0\n",
    "\n",
    "    print('over=', over)\n",
    "    #over= False\n",
    "\n",
    "\n",
    "test_env()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "Sequential(\n",
       "  (0): Linear(in_features=4, out_features=128, bias=True)\n",
       "  (1): ReLU()\n",
       "  (2): Linear(in_features=128, out_features=2, bias=True)\n",
       ")"
      ]
     },
     "execution_count": 4,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "import torch\n",
    "\n",
    "#计算动作的模型,也是真正要用的模型\n",
    "model = torch.nn.Sequential(\n",
    "    torch.nn.Linear(4, 128),\n",
    "    torch.nn.ReLU(),\n",
    "    torch.nn.Linear(128, 2),\n",
    ")\n",
    "\n",
    "model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "0"
      ]
     },
     "execution_count": 5,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "import random\n",
    "\n",
    "\n",
    "#得到一个动作\n",
    "def get_action(state):\n",
    "    if random.random() < 0.01:\n",
    "        return random.choice([0, 1])\n",
    "\n",
    "    #走神经网络,得到一个动作\n",
    "    state = torch.FloatTensor(state).reshape(1, 4)\n",
    "\n",
    "    return model(state).argmax().item()\n",
    "\n",
    "\n",
    "get_action([0.0013847, -0.01194451, 0.04260966, 0.00688801])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "((208, 0), 208)"
      ]
     },
     "execution_count": 6,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "#样本池\n",
    "datas = []\n",
    "\n",
    "\n",
    "#向样本池中添加N条数据,删除M条最古老的数据\n",
    "def update_data():\n",
    "    old_count = len(datas)\n",
    "\n",
    "    #玩到新增了N个数据为止\n",
    "    while len(datas) - old_count < 200:\n",
    "        #初始化游戏\n",
    "        state = env.reset()\n",
    "\n",
    "        #玩到游戏结束为止\n",
    "        over = False\n",
    "        while not over:\n",
    "            #根据当前状态得到一个动作\n",
    "            action = get_action(state)\n",
    "\n",
    "            #执行动作,得到反馈\n",
    "            next_state, reward, over, _ = env.step(action)\n",
    "\n",
    "            #记录数据样本\n",
    "            datas.append((state, action, reward, next_state, over))\n",
    "\n",
    "            #更新游戏状态,开始下一个动作\n",
    "            state = next_state\n",
    "\n",
    "    update_count = len(datas) - old_count\n",
    "    drop_count = max(len(datas) - 10000, 0)\n",
    "\n",
    "    #数据上限,超出时从最古老的开始删除\n",
    "    while len(datas) > 10000:\n",
    "        datas.pop(0)\n",
    "\n",
    "    return update_count, drop_count\n",
    "\n",
    "\n",
    "update_data(), len(datas)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/tmp/ipykernel_1222/1396937354.py:7: UserWarning: Creating a tensor from a list of numpy.ndarrays is extremely slow. Please consider converting the list to a single numpy.ndarray with numpy.array() before converting to a tensor. (Triggered internally at  ../torch/csrc/utils/tensor_new.cpp:201.)\n",
      "  state = torch.FloatTensor([i[0] for i in samples])\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "(tensor([[-0.0025, -0.9491,  0.0061,  1.3563],\n",
       "         [-0.1198, -1.3959,  0.1745,  2.2377],\n",
       "         [ 0.0128, -0.1778,  0.0111,  0.3327],\n",
       "         [-0.0972, -1.1807,  0.0682,  1.7701],\n",
       "         [-0.0101, -0.1841,  0.0268,  0.2873]]),\n",
       " tensor([0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n",
       "         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n",
       "         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]),\n",
       " tensor([1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,\n",
       "         1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,\n",
       "         1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,\n",
       "         1., 1., 1., 1., 1., 1., 1., 1., 1., 1.]),\n",
       " tensor([[-0.0214, -1.1443,  0.0332,  1.6509],\n",
       "         [-0.1477, -1.5922,  0.2193,  2.5787],\n",
       "         [ 0.0092, -0.3731,  0.0177,  0.6289],\n",
       "         [-0.1208, -1.3765,  0.1036,  2.0831],\n",
       "         [-0.0138, -0.3796,  0.0325,  0.5883]]),\n",
       " tensor([0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1,\n",
       "         0, 0, 1, 0, 1, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n",
       "         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1]))"
      ]
     },
     "execution_count": 7,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "#获取一批数据样本\n",
    "def get_sample():\n",
    "    #从样本池中采样\n",
    "    samples = random.sample(datas, 64)\n",
    "\n",
    "    #[b, 4]\n",
    "    state = torch.FloatTensor([i[0] for i in samples])\n",
    "    #[b]\n",
    "    action = torch.LongTensor([i[1] for i in samples])\n",
    "    #[b]\n",
    "    reward = torch.FloatTensor([i[2] for i in samples])\n",
    "    #[b, 4]\n",
    "    next_state = torch.FloatTensor([i[3] for i in samples])\n",
    "    #[b]\n",
    "    over = torch.LongTensor([i[4] for i in samples])\n",
    "\n",
    "    return state, action, reward, next_state, over\n",
    "\n",
    "\n",
    "state, action, reward, next_state, over = get_sample()\n",
    "\n",
    "state[:5], action, reward, next_state[:5], over"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([-0.1012, -0.2599,  0.0318, -0.1834,  0.0314,  0.0534,  0.0356, -0.0098,\n",
       "         0.0410,  0.0453, -0.2864, -0.1689, -0.0071,  0.0648, -0.1902, -0.2391,\n",
       "         0.0651, -0.2211,  0.0381, -0.1216, -0.1402, -0.1801,  0.0347, -0.2975,\n",
       "         0.0560,  0.0296, -0.2869, -0.1836, -0.2856, -0.1368, -0.0393, -0.3343,\n",
       "        -0.1197, -0.2345, -0.0319, -0.0046, -0.2700,  0.0633, -0.0794, -0.2334,\n",
       "        -0.2527,  0.0007, -0.2348, -0.0353,  0.0057, -0.2295,  0.0348, -0.0424,\n",
       "        -0.2762, -0.0709, -0.0013,  0.0605,  0.0303, -0.1931,  0.0486,  0.0197,\n",
       "        -0.2239, -0.0374, -0.2874,  0.0325, -0.0511,  0.0595, -0.1365, -0.3406],\n",
       "       grad_fn=<IndexBackward0>)"
      ]
     },
     "execution_count": 8,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "def get_value(state, action):\n",
    "    #使用状态计算出动作的logits\n",
    "    #[b, 4] -> [b, 2]\n",
    "    value = model(state)\n",
    "\n",
    "    #根据实际使用的action取出每一个值\n",
    "    #这个值就是模型评估的在该状态下,执行动作的分数\n",
    "    #在执行动作前,显然并不知道会得到的反馈和next_state\n",
    "    #所以这里不能也不需要考虑next_state和reward\n",
    "    #[b, 2] -> [b]\n",
    "    value = value[range(64), action]\n",
    "\n",
    "    return value\n",
    "\n",
    "\n",
    "get_value(state, action)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([0.8519, 1.0000, 0.9941, 0.7699, 0.9947, 1.0274, 1.0034, 0.9538, 1.0023,\n",
       "        1.0193, 0.6662, 0.7861, 0.9602, 1.0340, 0.7604, 0.7128, 1.0369, 0.7331,\n",
       "        1.0056, 0.8307, 0.8136, 0.7718, 1.0042, 1.0000, 1.0290, 0.9952, 1.0000,\n",
       "        0.7701, 1.0000, 0.8163, 0.9235, 1.0000, 0.8340, 0.7184, 0.9374, 0.9615,\n",
       "        0.6846, 1.0348, 0.8688, 0.7189, 0.7010, 0.9654, 0.7193, 0.9292, 0.9687,\n",
       "        0.7251, 0.9983, 0.9201, 0.6763, 0.8743, 0.9652, 1.0339, 0.9904, 0.7586,\n",
       "        1.0215, 0.9773, 0.7299, 0.9267, 0.6667, 0.9952, 0.9065, 1.0338, 0.8174,\n",
       "        1.0000])"
      ]
     },
     "execution_count": 9,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "def get_target(reward, next_state, over):\n",
    "    #上面已经把模型认为的状态下执行动作的分数给评估出来了\n",
    "    #下面使用next_state和reward计算真实的分数\n",
    "    #针对一个状态,它到底应该多少分,可以使用以往模型积累的经验评估\n",
    "    #这也是没办法的办法,因为显然没有精确解,这里使用延迟更新的next_model评估\n",
    "\n",
    "    #使用next_state计算下一个状态的分数\n",
    "    #[b, 4] -> [b, 2]\n",
    "    with torch.no_grad():\n",
    "        target = model(next_state)\n",
    "\n",
    "    #取所有动作中分数最大的\n",
    "    #[b, 2] -> [b]\n",
    "    target = target.max(dim=1)[0]\n",
    "\n",
    "    #如果next_state已经游戏结束,则next_state的分数是0\n",
    "    #因为如果下一步已经游戏结束,显然不需要再继续玩下去,也就不需要考虑next_state了.\n",
    "    #[b]\n",
    "    for i in range(64):\n",
    "        if over[i]:\n",
    "            target[i] = 0\n",
    "\n",
    "    #下一个状态的分数乘以一个系数,相当于权重\n",
    "    #[b] * [b] -> [b]\n",
    "    target *= 0.98\n",
    "\n",
    "    #加上reward就是最终的分数\n",
    "    #[b] + [b] -> [b]\n",
    "    target += reward\n",
    "\n",
    "    return target\n",
    "\n",
    "\n",
    "get_target(reward, next_state, over)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "10.0"
      ]
     },
     "execution_count": 10,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "from IPython import display\n",
    "\n",
    "\n",
    "def test(play):\n",
    "    #初始化游戏\n",
    "    state = env.reset()\n",
    "\n",
    "    #记录反馈值的和,这个值越大越好\n",
    "    reward_sum = 0\n",
    "\n",
    "    #玩到游戏结束为止\n",
    "    over = False\n",
    "    while not over:\n",
    "        #根据当前状态得到一个动作\n",
    "        action = get_action(state)\n",
    "\n",
    "        #执行动作,得到反馈\n",
    "        state, reward, over, _ = env.step(action)\n",
    "        reward_sum += reward\n",
    "\n",
    "        #打印动画\n",
    "        if play:\n",
    "            display.clear_output(wait=True)\n",
    "            show()\n",
    "\n",
    "    return reward_sum\n",
    "\n",
    "\n",
    "test(play=False)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {
    "id": "OHoSU6uI-xIt",
    "scrolled": false
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "0 411 203 0 9.25\n",
      "50 10000 394 394 179.55\n",
      "100 10000 200 200 200.0\n",
      "150 10000 200 200 199.15\n",
      "200 10000 200 200 190.3\n",
      "250 10000 200 200 198.45\n",
      "300 10000 344 344 184.95\n",
      "350 10000 200 200 200.0\n",
      "400 10000 200 200 200.0\n",
      "450 10000 200 200 198.3\n"
     ]
    }
   ],
   "source": [
    "def train():\n",
    "    model.train()\n",
    "    optimizer = torch.optim.Adam(model.parameters(), lr=2e-3)\n",
    "    loss_fn = torch.nn.MSELoss()\n",
    "\n",
    "    #训练N次\n",
    "    for epoch in range(500):\n",
    "        #更新N条数据\n",
    "        update_count, drop_count = update_data()\n",
    "\n",
    "        #每次更新过数据后,学习N次\n",
    "        for i in range(200):\n",
    "            #采样一批数据\n",
    "            state, action, reward, next_state, over = get_sample()\n",
    "\n",
    "            #计算一批样本的value和target\n",
    "            value = get_value(state, action)\n",
    "            target = get_target(reward, next_state, over)\n",
    "\n",
    "            #更新参数\n",
    "            loss = loss_fn(value, target)\n",
    "            optimizer.zero_grad()\n",
    "            loss.backward()\n",
    "            optimizer.step()\n",
    "\n",
    "        if epoch % 50 == 0:\n",
    "            test_result = sum([test(play=False) for _ in range(20)]) / 20\n",
    "            print(epoch, len(datas), update_count, drop_count, test_result)\n",
    "\n",
    "\n",
    "train()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {
    "scrolled": false
   },
   "outputs": [
    {
     "data": {
      "image/png": "iVBORw0KGgoAAAANSUhEUgAAAigAAAF7CAYAAAD4/3BBAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjYuMywgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/P9b71AAAACXBIWXMAAA9hAAAPYQGoP6dpAAAmJklEQVR4nO3df3DU9YH/8deGJCu/dmOAZJOSIAoFIgR7gGHP1rNHSoBI5YzfUctJ2mNg5BKnEEoxPSrFuzEc3lzVnsLc3J14M0ZaOqInFbw0SDhr+GFKjoCaE442WLIJymU3BNn8en//8Mvne6uobFjY98bnY+Yzk93Pez/73vdkJs/Z/ewnLmOMEQAAgEWS4j0BAACATyJQAACAdQgUAABgHQIFAABYh0ABAADWIVAAAIB1CBQAAGAdAgUAAFiHQAEAANYhUAAAgHXiGihPP/20brjhBl133XUqKCjQwYMH4zkdAABgibgFys9//nNVVFRo/fr1+u1vf6vp06erqKhI7e3t8ZoSAACwhCte/yywoKBAs2bN0j/8wz9Ikvr7+5WTk6OHHnpIDz/8cDymBAAALJEcjyft7u5WQ0ODKisrnfuSkpJUWFio+vr6T40Ph8MKh8PO7f7+fp09e1ajRo2Sy+W6JnMGAABXxhijzs5OZWdnKynp8z/EiUugfPDBB+rr61NmZmbE/ZmZmXr33Xc/Nb6qqkobNmy4VtMDAABX0alTpzR27NjPHROXQIlWZWWlKioqnNvBYFC5ubk6deqUPB5PHGcGAAAuVygUUk5OjkaOHPmFY+MSKKNHj9aQIUPU1tYWcX9bW5t8Pt+nxrvdbrnd7k/d7/F4CBQAABLM5ZyeEZdv8aSmpmrGjBmqra117uvv71dtba38fn88pgQAACwSt494KioqVFpaqpkzZ+rWW2/VE088oa6uLn3ve9+L15QAAIAl4hYo9957r86cOaNHHnlEgUBAt9xyi3bv3v2pE2cBAMCXT9yug3IlQqGQvF6vgsEg56AAAJAgovn7zf/iAQAA1iFQAACAdQgUAABgHQIFAABYh0ABAADWIVAAAIB1CBQAAGAdAgUAAFiHQAEAANYhUAAAgHUIFAAAYB0CBQAAWIdAAQAA1iFQAACAdQgUAABgHQIFAABYh0ABAADWIVAAAIB1CBQAAGAdAgUAAFiHQAEAANYhUAAAgHUIFAAAYB0CBQAAWIdAAQAA1iFQAACAdQgUAABgHQIFAABYh0ABAADWIVAAAIB1Yh4oP/nJT+RyuSK2yZMnO/svXLigsrIyjRo1SiNGjFBJSYna2tpiPQ0AAJDArso7KDfffLNaW1ud7Y033nD2rVq1Sq+88oq2b9+uuro6nT59WnfffffVmAYAAEhQyVfloMnJ8vl8n7o/GAzqn//5n1VdXa0//dM/lSQ9++yzmjJlivbv36/Zs2dfjekAAIAEc1XeQXnvvfeUnZ2tG2+8UYsXL1ZLS4skqaGhQT09PSosLHTGTp48Wbm5uaqvr//M44XDYYVCoYgNAAAMXjEPlIKCAm3dulW7d+/W5s2bdfLkSX3jG99QZ2enAoGAUlNTlZaWFvGYzMxMBQKBzzxmVVWVvF6vs+Xk5MR62gAAwCIx/4hn/vz5zs/5+fkqKCjQuHHj9Itf/EJDhw4d0DErKytVUVHh3A6FQkQKAACD2FX/mnFaWpq++tWv6vjx4/L5fOru7lZHR0fEmLa2tkues3KR2+2Wx+OJ2AAAwOB11QPl3LlzOnHihLKysjRjxgylpKSotrbW2d/c3KyWlhb5/f6rPRUAAJAgYv4Rzw9+8AMtXLhQ48aN0+nTp7V+/XoNGTJE999/v7xer5YuXaqKigqlp6fL4/HooYcekt/v5xs8AADAEfNAef/993X//ffrww8/1JgxY/T1r39d+/fv15gxYyRJP/3pT5WUlKSSkhKFw2EVFRXpmWeeifU0AABAAnMZY0y8JxGtUCgkr9erYDDI+SgAACSIaP5+8794AACAdQgUAABgHQIFAABYh0ABAADWIVAAAIB1CBQAAGAdAgUAAFiHQAEAANYhUAAAgHUIFAAAYB0CBQAAWIdAAQAA1iFQAACAdQgUAABgHQIFAABYh0ABAADWIVAAAIB1CBQAAGAdAgUAAFiHQAEAANYhUAAAgHUIFAAAYB0CBQAAWIdAAQAA1iFQAACAdQgUAABgHQIFAABYh0ABAADWIVAAAIB1CBQAAGAdAgUAAFgn6kDZt2+fFi5cqOzsbLlcLr300ksR+40xeuSRR5SVlaWhQ4eqsLBQ7733XsSYs2fPavHixfJ4PEpLS9PSpUt17ty5K3ohAABg8Ig6ULq6ujR9+nQ9/fTTl9y/adMmPfXUU9qyZYsOHDig4cOHq6ioSBcuXHDGLF68WMeOHVNNTY127typffv2afny5QN/FQAAYFBxGWPMgB/scmnHjh1atGiRpI/fPcnOztbq1av1gx/8QJIUDAaVmZmprVu36r777tM777yjvLw8HTp0SDNnzpQk7d69WwsWLND777+v7OzsL3zeUCgkr9erYDAoj8cz0OkDAIBrKJq/3zE9B+XkyZMKBAIqLCx07vN6vSooKFB9fb0kqb6+XmlpaU6cSFJhYaGSkpJ04MCBSx43HA4rFApFbAAAYPCKaaAEAgFJUmZmZsT9mZmZzr5AIKCMjIyI/cnJyUpPT3fGfFJVVZW8Xq+z5eTkxHLaAADAMgnxLZ7KykoFg0FnO3XqVLynBAAArqKYBorP55MktbW1Rdzf1tbm7PP5fGpvb4/Y39vbq7NnzzpjPsntdsvj8URsAABg8IppoIwfP14+n0+1tbXOfaFQSAcOHJDf75ck+f1+dXR0qKGhwRmzZ88e9ff3q6CgIJbTAQAACSo52gecO3dOx48fd26fPHlSjY2NSk9PV25urlauXKm/+Zu/0cSJEzV+/Hj9+Mc/VnZ2tvNNnylTpmjevHlatmyZtmzZop6eHpWXl+u+++67rG/wAACAwS/qQHnrrbf0zW9+07ldUVEhSSotLdXWrVv1wx/+UF1dXVq+fLk6Ojr09a9/Xbt379Z1113nPOb5559XeXm55syZo6SkJJWUlOipp56KwcsBAACDwRVdByVeuA4KAACJJ27XQQEAAIgFAgUAAFiHQAEAANYhUAAAgHUIFAAAYB0CBQAAWIdAAQAA1iFQAACAdQgUAABgHQIFAABYh0ABAADWIVAAAIB1CBQAAGAdAgUAAFiHQAEAANYhUAAAgHUIFAAAYB0CBQAAWIdAAQAA1iFQAACAdQgUAABgHQIFAABYh0ABAADWIVAAAIB1CBQAAGAdAgUAAFiHQAEAANYhUAAAgHUIFAAAYB0CBQAAWCfqQNm3b58WLlyo7OxsuVwuvfTSSxH7v/vd78rlckVs8+bNixhz9uxZLV68WB6PR2lpaVq6dKnOnTt3RS8EAAAMHlEHSldXl6ZPn66nn376M8fMmzdPra2tzvbCCy9E7F+8eLGOHTummpoa7dy5U/v27dPy5cujnz0AABiUkqN9wPz58zV//vzPHeN2u+Xz+S6575133tHu3bt16NAhzZw5U5L0s5/9TAsWLNDf/d3fKTs7O9opAQCAQeaqnIOyd+9eZWRkaNKkSVqxYoU+/PBDZ199fb3S0tKcOJGkwsJCJSUl6cCBA5c8XjgcVigUitgAAMDgFfNAmTdvnv71X/9VtbW1+tu//VvV1dVp/vz56uvrkyQFAgFlZGREPCY5OVnp6ekKBAKXPGZVVZW8Xq+z5eTkxHraAADAIlF/xPNF7rvvPufnadOmKT8/XzfddJP27t2rOXPmDOiYlZWVqqiocG6HQiEiBQCAQeyqf834xhtv1OjRo3X8+HFJks/nU3t7e8SY3t5enT179jPPW3G73fJ4PBEbAAAYvK56oLz//vv68MMPlZWVJUny+/3q6OhQQ0ODM2bPnj3q7+9XQUHB1Z4OAABIAFF/xHPu3Dnn3RBJOnnypBobG5Wenq709HRt2LBBJSUl8vl8OnHihH74wx9qwoQJKioqkiRNmTJF8+bN07Jly7Rlyxb19PSovLxc9913H9/gAQAAkiSXMcZE84C9e/fqm9/85qfuLy0t1ebNm7Vo0SIdPnxYHR0dys7O1ty5c/XXf/3XyszMdMaePXtW5eXleuWVV5SUlKSSkhI99dRTGjFixGXNIRQKyev1KhgM8nEPAAAJIpq/31EHig0IFAAAEk80f7/5XzwAAMA6BAoAALAOgQIAAKxDoAAAAOsQKAAAwDoECgAAsA6BAgAArEOgAAAA6xAoAADAOgQKAACwDoECAACsQ6AAAADrECgAAMA6BAoAALAOgQIAAKxDoAAAAOsQKAAAwDoECgAAsA6BAgAArEOgAAAA6xAoAADAOgQKAACwDoECAACsQ6AAAADrECgAAMA6BAoAALAOgQIAAKxDoAAAAOsQKAAAwDoECgAAsE5UgVJVVaVZs2Zp5MiRysjI0KJFi9Tc3Bwx5sKFCyorK9OoUaM0YsQIlZSUqK2tLWJMS0uLiouLNWzYMGVkZGjNmjXq7e298lcDAAAGhagCpa6uTmVlZdq/f79qamrU09OjuXPnqquryxmzatUqvfLKK9q+fbvq6up0+vRp3X333c7+vr4+FRcXq7u7W2+++aaee+45bd26VY888kjsXhUAAEhoLmOMGeiDz5w5o4yMDNXV1en2229XMBjUmDFjVF1drXvuuUeS9O6772rKlCmqr6/X7NmztWvXLt155506ffq0MjMzJUlbtmzR2rVrdebMGaWmpn7h84ZCIXm9XgWDQXk8noFOHwAAXEPR/P2+onNQgsGgJCk9PV2S1NDQoJ6eHhUWFjpjJk+erNzcXNXX10uS6uvrNW3aNCdOJKmoqEihUEjHjh275POEw2GFQqGIDQAADF4DDpT+/n6tXLlSt912m6ZOnSpJCgQCSk1NVVpaWsTYzMxMBQIBZ8z/jpOL+y/uu5Sqqip5vV5ny8nJGei0AQBAAhhwoJSVleno0aPatm1bLOdzSZWVlQoGg8526tSpq/6cAAAgfpIH8qDy8nLt3LlT+/bt09ixY537fT6furu71dHREfEuSltbm3w+nzPm4MGDEce7+C2fi2M+ye12y+12D2SqAAAgAUX1DooxRuXl5dqxY4f27Nmj8ePHR+yfMWOGUlJSVFtb69zX3NyslpYW+f1+SZLf71dTU5Pa29udMTU1NfJ4PMrLy7uS1wIAAAaJqN5BKSsrU3V1tV5++WWNHDnSOWfE6/Vq6NCh8nq9Wrp0qSoqKpSeni6Px6OHHnpIfr9fs2fPliTNnTtXeXl5euCBB7Rp0yYFAgGtW7dOZWVlvEsCAAAkRfk1Y5fLdcn7n332WX33u9+V9PGF2lavXq0XXnhB4XBYRUVFeuaZZyI+vvn973+vFStWaO/evRo+fLhKS0u1ceNGJSdfXi/xNWMAABJPNH+/r+g6KPFCoAAAkHiu2XVQAAAArgYCBQAAWIdAAQAA1iFQAACAdQgUAABgHQIFAABYh0ABAADWIVAAAIB1CBQAAGAdAgUAAFiHQAEAANYhUAAAgHUIFAAAYB0CBQAAWIdAAQAA1iFQAACAdQgUAABgHQIFAABYh0ABAADWIVAAAIB1CBQAAGAdAgUAAFiHQAEAANYhUAAAgHUIFAAAYB0CBQAAWIdAAQAA1iFQAACAdQgUAABgHQIFAABYh0ABAADWiSpQqqqqNGvWLI0cOVIZGRlatGiRmpubI8bccccdcrlcEduDDz4YMaalpUXFxcUaNmyYMjIytGbNGvX29l75qwEAAINCcjSD6+rqVFZWplmzZqm3t1c/+tGPNHfuXL399tsaPny4M27ZsmV69NFHndvDhg1zfu7r61NxcbF8Pp/efPNNtba2asmSJUpJSdFjjz0Wg5cEAAASncsYYwb64DNnzigjI0N1dXW6/fbbJX38Dsott9yiJ5544pKP2bVrl+68806dPn1amZmZkqQtW7Zo7dq1OnPmjFJTU7/weUOhkLxer4LBoDwez0CnDwAArqFo/n5f0TkowWBQkpSenh5x//PPP6/Ro0dr6tSpqqys1Pnz55199fX1mjZtmhMnklRUVKRQKKRjx45d8nnC4bBCoVDEBgAABq+oPuL53/r7+7Vy5Urddtttmjp1qnP/d77zHY0bN07Z2dk6cuSI1q5dq+bmZr344ouSpEAgEBEnkpzbgUDgks9VVVWlDRs2fHoOvd0DnT4AALDYgAOlrKxMR48e1RtvvBFx//Lly52fp02bpqysLM2ZM0cnTpzQTTfdNKDnqqysVEVFhXM7FAopJydHPR91Sho9oGMCAAB7DegjnvLycu3cuVOvv/66xo4d+7ljCwoKJEnHjx+XJPl8PrW1tUWMuXjb5/Nd8hhut1sejydik/T/AgUAAAw2UQWKMUbl5eXasWOH9uzZo/Hjx3/hYxobGyVJWVlZkiS/36+mpia1t7c7Y2pqauTxeJSXlxfNdNTzEeeiAAAwGEX1EU9ZWZmqq6v18ssva+TIkc45I16vV0OHDtWJEydUXV2tBQsWaNSoUTpy5IhWrVql22+/Xfn5+ZKkuXPnKi8vTw888IA2bdqkQCCgdevWqaysTG63O6rJX+ho++JBAAAg4UT1DsrmzZsVDAZ1xx13KCsry9l+/vOfS5JSU1P161//WnPnztXkyZO1evVqlZSU6JVXXnGOMWTIEO3cuVNDhgyR3+/Xn//5n2vJkiUR1025XOc/aIn6MQAAwH5RvYPyRZdMycnJUV1d3RceZ9y4cXr11VejeWoAAPAlkuD/i8d8YTQBAIDEk9CB0t/Xo77uj+I9DQAAEGOJHSg93QQKAACDUGIHSm9YveHzXzwQAAAklIQOlL6ebvV1EygAAAw2CR0o4c4z6jrzu3hPAwAAxFhCB4qMkenvj/csAABAjCV2oAAAgEEp4QPF9HbL9PfFexoAACCGEj5QesPn1d/bHe9pAACAGBoEgdKl/t6eeE8DAADEUMIHSrjzA66FAgDAIJPwgdLV9t/qPnc23tMAAAAxlPCBAgAABh8CBQAAWGeQBEq/jDHxngQAAIiRQREoPedDkuGKsgAADBaDIlC6zwdlCBQAAAaNQREoPeeDMv18xAMAwGCRHO8JxMIHzW8qM79IxjXw3nK5XBoyZEgMZwUAAAZqUARKf88F3XDDOH0Y+mjAx1i4cKFefPHFGM4KAAAM1KAIFEnq7e1Vb2/vgB/f18c/HAQAwBYJfQ7KqQuT1NXnifc0AABAjCV0oLx3fqYaO+co2DtKI4elxns6AAAgRhI6UPpMikK9Y3QoWKz067PiPR0AABAjCR0oF/WY6zT///xjvKcBAABiZFAEivTx14QBAMDgMGgCBQAADB4ECgAAsM6gCJQk9eq2NC6yBgDAYBFVoGzevFn5+fnyeDzyeDzy+/3atWuXs//ChQsqKyvTqFGjNGLECJWUlKitrS3iGC0tLSouLtawYcOUkZGhNWvWDPgCa6a/R8n9Z3WjfqHTrScGdAwAAGCfqK4kO3bsWG3cuFETJ06UMUbPPfec7rrrLh0+fFg333yzVq1apV/96lfavn27vF6vysvLdffdd+s3v/mNpI+v1lpcXCyfz6c333xTra2tWrJkiVJSUvTYY49FPfnO1hq1/0+Tdp94W/Vvvx/14wEAgJ1cxpgr+jfA6enpevzxx3XPPfdozJgxqq6u1j333CNJevfddzVlyhTV19dr9uzZ2rVrl+68806dPn1amZmZkqQtW7Zo7dq1OnPmjFJTL+9ia6FQSF6vVzdmX68/nAkp3HPll6kfN26cioqKrvg4AADg0rq7u7V161YFg0F5PJ9/JfgB/y+evr4+bd++XV1dXfL7/WpoaFBPT48KCwudMZMnT1Zubq4TKPX19Zo2bZoTJ5JUVFSkFStW6NixY/ra1752yecKh8MKh8PO7VAoJEn679P/M9Dpf0pubq6WLl0as+MBAIBI586d09atWy9rbNSB0tTUJL/frwsXLmjEiBHasWOH8vLy1NjYqNTUVKWlpUWMz8zMVCAQkCQFAoGIOLm4/+K+z1JVVaUNGzZEO9WoXH/99br11luv6nMAAPBldvENhssR9bd4Jk2apMbGRh04cEArVqxQaWmp3n777WgPE5XKykoFg0FnO3Xq1FV9PgAAEF9Rv4OSmpqqCRMmSJJmzJihQ4cO6cknn9S9996r7u5udXR0RLyL0tbWJp/PJ0ny+Xw6ePBgxPEufsvn4phLcbvdcrvd0U4VAAAkqCu+Dkp/f7/C4bBmzJihlJQU1dbWOvuam5vV0tIiv98vSfL7/WpqalJ7e7szpqamRh6PR3l5eVc6FQAAMEhE9Q5KZWWl5s+fr9zcXHV2dqq6ulp79+7Va6+9Jq/Xq6VLl6qiokLp6enyeDx66KGH5Pf7NXv2bEnS3LlzlZeXpwceeECbNm1SIBDQunXrVFZWxjskAADAEVWgtLe3a8mSJWptbZXX61V+fr5ee+01fetb35Ik/fSnP1VSUpJKSkoUDodVVFSkZ555xnn8kCFDtHPnTq1YsUJ+v1/Dhw9XaWmpHn300di+KgAAkNCu+Doo8XDxOiix9O1vf1svv/xyTI8JAAD+v4t/vy/nOiiD4n/xAACAwYVAAQAA1iFQAACAdQgUAABgnQH/Lx4bFBcXKyUlJSbH4jL3AADYI6EDpbq6+gvPAgYAAImHj3gAAIB1CBQAAGAdAgUAAFiHQAEAANYhUAAAgHUIFAAAYB0CBQAAWIdAAQAA1iFQAACAdQgUAABgHQIFAABYh0ABAADWIVAAAIB1CBQAAGAdAgUAAFiHQAEAANYhUAAAgHUIFAAAYB0CBQAAWIdAAQAA1iFQAACAdQgUAABgHQIFAABYh0ABAADWiSpQNm/erPz8fHk8Hnk8Hvn9fu3atcvZf8cdd8jlckVsDz74YMQxWlpaVFxcrGHDhikjI0Nr1qxRb29vbF4NAAAYFJKjGTx27Fht3LhREydOlDFGzz33nO666y4dPnxYN998syRp2bJlevTRR53HDBs2zPm5r69PxcXF8vl8evPNN9Xa2qolS5YoJSVFjz32WIxeEgAASHQuY4y5kgOkp6fr8ccf19KlS3XHHXfolltu0RNPPHHJsbt27dKdd96p06dPKzMzU5K0ZcsWrV27VmfOnFFqauplPWcoFJLX61UwGJTH47mS6QMAgGskmr/fAz4Hpa+vT9u2bVNXV5f8fr9z//PPP6/Ro0dr6tSpqqys1Pnz55199fX1mjZtmhMnklRUVKRQKKRjx4595nOFw2GFQqGIDQAADF5RfcQjSU1NTfL7/bpw4YJGjBihHTt2KC8vT5L0ne98R+PGjVN2draOHDmitWvXqrm5WS+++KIkKRAIRMSJJOd2IBD4zOesqqrShg0bop0qAABIUFEHyqRJk9TY2KhgMKhf/vKXKi0tVV1dnfLy8rR8+XJn3LRp05SVlaU5c+boxIkTuummmwY8ycrKSlVUVDi3Q6GQcnJyBnw8AABgt6g/4klNTdWECRM0Y8YMVVVVafr06XryyScvObagoECSdPz4cUmSz+dTW1tbxJiLt30+32c+p9vtdr45dHEDAACD1xVfB6W/v1/hcPiS+xobGyVJWVlZkiS/36+mpia1t7c7Y2pqauTxeJyPiQAAAKL6iKeyslLz589Xbm6uOjs7VV1drb179+q1117TiRMnVF1drQULFmjUqFE6cuSIVq1apdtvv135+fmSpLlz5yovL08PPPCANm3apEAgoHXr1qmsrExut/uqvEAAAJB4ogqU9vZ2LVmyRK2trfJ6vcrPz9drr72mb33rWzp16pR+/etf64knnlBXV5dycnJUUlKidevWOY8fMmSIdu7cqRUrVsjv92v48OEqLS2NuG4KAADAFV8HJR64DgoAAInnmlwHBQAA4GohUAAAgHUIFAAAYB0CBQAAWIdAAQAA1iFQAACAdQgUAABgHQIFAABYh0ABAADWIVAAAIB1CBQAAGAdAgUAAFiHQAEAANYhUAAAgHUIFAAAYB0CBQAAWIdAAQAA1iFQAACAdQgUAABgHQIFAABYh0ABAADWIVAAAIB1CBQAAGAdAgUAAFiHQAEAANYhUAAAgHUIFAAAYB0CBQAAWIdAAQAA1iFQAACAdQgUAABgHQIFAABYh0ABAADWSY73BAbCGCNJCoVCcZ4JAAC4XBf/bl/8O/55EjJQOjs7JUk5OTlxngkAAIhWZ2envF7v545xmcvJGMv09/erublZeXl5OnXqlDweT7ynlLBCoZBycnJYxxhgLWOHtYwN1jF2WMvYMMaos7NT2dnZSkr6/LNMEvIdlKSkJH3lK1+RJHk8Hn5ZYoB1jB3WMnZYy9hgHWOHtbxyX/TOyUWcJAsAAKxDoAAAAOskbKC43W6tX79ebrc73lNJaKxj7LCWscNaxgbrGDus5bWXkCfJAgCAwS1h30EBAACDF4ECAACsQ6AAAADrECgAAMA6CRkoTz/9tG644QZdd911Kigo0MGDB+M9Jevs27dPCxcuVHZ2tlwul1566aWI/cYYPfLII8rKytLQoUNVWFio9957L2LM2bNntXjxYnk8HqWlpWnp0qU6d+7cNXwV8VdVVaVZs2Zp5MiRysjI0KJFi9Tc3Bwx5sKFCyorK9OoUaM0YsQIlZSUqK2tLWJMS0uLiouLNWzYMGVkZGjNmjXq7e29li8lrjZv3qz8/HznIld+v1+7du1y9rOGA7dx40a5XC6tXLnSuY/1vDw/+clP5HK5IrbJkyc7+1nHODMJZtu2bSY1NdX8y7/8izl27JhZtmyZSUtLM21tbfGemlVeffVV81d/9VfmxRdfNJLMjh07IvZv3LjReL1e89JLL5n//M//NN/+9rfN+PHjzUcffeSMmTdvnpk+fbrZv3+/+Y//+A8zYcIEc//991/jVxJfRUVF5tlnnzVHjx41jY2NZsGCBSY3N9ecO3fOGfPggw+anJwcU1tba9566y0ze/Zs88d//MfO/t7eXjN16lRTWFhoDh8+bF599VUzevRoU1lZGY+XFBf/9m//Zn71q1+Z//qv/zLNzc3mRz/6kUlJSTFHjx41xrCGA3Xw4EFzww03mPz8fPP973/fuZ/1vDzr1683N998s2ltbXW2M2fOOPtZx/hKuEC59dZbTVlZmXO7r6/PZGdnm6qqqjjOym6fDJT+/n7j8/nM448/7tzX0dFh3G63eeGFF4wxxrz99ttGkjl06JAzZteuXcblcpk//OEP12zutmlvbzeSTF1dnTHm43VLSUkx27dvd8a88847RpKpr683xnwci0lJSSYQCDhjNm/ebDwejwmHw9f2BVjk+uuvN//0T//EGg5QZ2enmThxoqmpqTF/8id/4gQK63n51q9fb6ZPn37Jfaxj/CXURzzd3d1qaGhQYWGhc19SUpIKCwtVX18fx5kllpMnTyoQCESso9frVUFBgbOO9fX1SktL08yZM50xhYWFSkpK0oEDB675nG0RDAYlSenp6ZKkhoYG9fT0RKzl5MmTlZubG7GW06ZNU2ZmpjOmqKhIoVBIx44du4azt0NfX5+2bdumrq4u+f1+1nCAysrKVFxcHLFuEr+T0XrvvfeUnZ2tG2+8UYsXL1ZLS4sk1tEGCfXPAj/44AP19fVF/DJIUmZmpt599904zSrxBAIBSbrkOl7cFwgElJGREbE/OTlZ6enpzpgvm/7+fq1cuVK33Xabpk6dKunjdUpNTVVaWlrE2E+u5aXW+uK+L4umpib5/X5duHBBI0aM0I4dO5SXl6fGxkbWMErbtm3Tb3/7Wx06dOhT+/idvHwFBQXaunWrJk2apNbWVm3YsEHf+MY3dPToUdbRAgkVKEA8lZWV6ejRo3rjjTfiPZWENGnSJDU2NioYDOqXv/ylSktLVVdXF+9pJZxTp07p+9//vmpqanTdddfFezoJbf78+c7P+fn5Kigo0Lhx4/SLX/xCQ4cOjePMICXYt3hGjx6tIUOGfOos6ra2Nvl8vjjNKvFcXKvPW0efz6f29vaI/b29vTp79uyXcq3Ly8u1c+dOvf766xo7dqxzv8/nU3d3tzo6OiLGf3ItL7XWF/d9WaSmpmrChAmaMWOGqqqqNH36dD355JOsYZQaGhrU3t6uP/qjP1JycrKSk5NVV1enp556SsnJycrMzGQ9BygtLU1f/epXdfz4cX4vLZBQgZKamqoZM2aotrbWua+/v1+1tbXy+/1xnFliGT9+vHw+X8Q6hkIhHThwwFlHv9+vjo4ONTQ0OGP27Nmj/v5+FRQUXPM5x4sxRuXl5dqxY4f27Nmj8ePHR+yfMWOGUlJSItayublZLS0tEWvZ1NQUEXw1NTXyeDzKy8u7Ni/EQv39/QqHw6xhlObMmaOmpiY1NjY628yZM7V48WLnZ9ZzYM6dO6cTJ04oKyuL30sbxPss3Wht27bNuN1us3XrVvP222+b5cuXm7S0tIizqPHxGf6HDx82hw8fNpLM3//935vDhw+b3//+98aYj79mnJaWZl5++WVz5MgRc9ddd13ya8Zf+9rXzIEDB8wbb7xhJk6c+KX7mvGKFSuM1+s1e/fujfgq4vnz550xDz74oMnNzTV79uwxb731lvH7/cbv9zv7L34Vce7cuaaxsdHs3r3bjBkz5kv1VcSHH37Y1NXVmZMnT5ojR46Yhx9+2LhcLvPv//7vxhjW8Er972/xGMN6Xq7Vq1ebvXv3mpMnT5rf/OY3prCw0IwePdq0t7cbY1jHeEu4QDHGmJ/97GcmNzfXpKammltvvdXs378/3lOyzuuvv24kfWorLS01xnz8VeMf//jHJjMz07jdbjNnzhzT3NwccYwPP/zQ3H///WbEiBHG4/GY733ve6azszMOryZ+LrWGksyzzz7rjPnoo4/MX/7lX5rrr7/eDBs2zPzZn/2ZaW1tjTjO7373OzN//nwzdOhQM3r0aLN69WrT09NzjV9N/PzFX/yFGTdunElNTTVjxowxc+bMceLEGNbwSn0yUFjPy3PvvfearKwsk5qaar7yla+Ye++91xw/ftzZzzrGl8sYY+Lz3g0AAMClJdQ5KAAA4MuBQAEAANYhUAAAgHUIFAAAYB0CBQAAWIdAAQAA1iFQAACAdQgUAABgHQIFAABYh0ABAADWIVAAAIB1CBQAAGCd/wt6HhHvaZjX7QAAAABJRU5ErkJggg==\n",
      "text/plain": [
       "<Figure size 640x480 with 1 Axes>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "text/plain": [
       "200.0"
      ]
     },
     "execution_count": 12,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "test(play=True)"
   ]
  }
 ],
 "metadata": {
  "colab": {
   "collapsed_sections": [],
   "name": "第7章-DQN算法.ipynb",
   "provenance": []
  },
  "kernelspec": {
   "display_name": "Python [conda env:pt39]",
   "language": "python",
   "name": "conda-env-pt39-py"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.13"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 1
}
